One common classification method is the non-parametric k-Nearest Neighbors (k-NN) algorithm. This supervised learning method classifies an unknown data point on the assumption that its class will match that of nearby points. The model begins by calculating the distance (often Euclidean, though other measures can be used) from the unknown point to every point in the training data and identifying the k nearest neighbors. It then counts the class memberships among those neighbors and assigns the unknown point the class with the plurality of votes. For example, in a k = 3 model, if the 3 nearest points are class 1, class 1, and class 2, the model classifies the unknown point as class 1. This is a straightforward classification technique that can be applied to many problems.
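As a concrete sketch of the mechanics, the distance-and-vote procedure can be written in a few lines of base R. This is a toy example with made-up points, not part of the tutorial's dataset:

```r
# Toy illustration: classify one unknown point by plurality vote
# among its k = 3 nearest neighbors (Euclidean distance).
train_x <- matrix(c(1, 1, 1.2, 0.9, 5, 5, 5.1, 4.8), ncol = 2, byrow = TRUE)
train_y <- c("class 1", "class 1", "class 2", "class 2")
unknown <- c(1.1, 1.0)

# Euclidean distance from the unknown point to every training point
dists <- sqrt(rowSums(sweep(train_x, 2, unknown)^2))

# take the labels of the k nearest points and vote
k <- 3
nearest <- train_y[order(dists)[1:k]]
names(which.max(table(nearest)))  # "class 1" wins 2 votes to 1
```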
This tutorial serves as an introduction to the k-Nearest Neighbors classification technique and covers:
Replication Requirements: What you’ll need to reproduce the analysis in this tutorial.
Train Test Split: How to separate the data frame into a Training set and a Testing set.
Modeling Using knn: How to use the knn function from the class package.
Refining the model: How to refine the model using best subset with the leaps package.
Selecting the best k: How to select the k that gives the best test set accuracy.
Points of Note: Points of note regarding the use of k-NN.
Learn More: Visit these resources to learn more about k-NN.
To reproduce the results below you will need to download the gender voice dataset located at the following link. Make sure you save this file as voicedata.RData. Additionally, the tidyverse, plotly, class, knitr, kableExtra, and leaps packages will be used for the analysis.
After saving, load it into your environment with load("voicedata.RData"). We begin by inspecting the dataset.
load("voicedata.RData")
| meanfreq | sd | median | Q25 | Q75 | IQR | skew | kurt | sp.ent | sfm | mode | centroid | meanfun | minfun | maxfun | meandom | mindom | maxdom | dfrange | modindx | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.060 | 0.064 | 0.032 | 0.015 | 0.090 | 0.075 | 12.86 | 274.40 | 0.893 | 0.492 | NA | 0.060 | 0.084 | 0.016 | 0.276 | 0.008 | 0.008 | 0.008 | NA | NA | male |
| 0.066 | 0.067 | 0.040 | 0.019 | 0.093 | 0.073 | 22.42 | 634.61 | 0.892 | 0.514 | NA | 0.066 | 0.108 | 0.016 | 0.250 | 0.009 | 0.008 | 0.055 | 0.047 | 0.053 | male |
| 0.077 | 0.084 | 0.037 | 0.009 | 0.132 | 0.123 | 30.76 | 1024.93 | 0.846 | 0.479 | NA | 0.077 | 0.099 | 0.016 | 0.271 | 0.008 | 0.008 | 0.016 | 0.008 | 0.047 | male |
| 0.151 | 0.072 | 0.158 | 0.097 | 0.208 | 0.111 | 1.23 | 4.18 | 0.963 | 0.727 | 0.084 | 0.151 | 0.089 | 0.018 | 0.250 | 0.201 | 0.008 | 0.562 | 0.555 | 0.247 | male |
| 0.135 | 0.079 | 0.125 | 0.079 | 0.206 | 0.127 | 1.10 | 4.33 | 0.972 | 0.784 | 0.104 | 0.135 | 0.106 | 0.017 | 0.267 | 0.713 | 0.008 | 5.484 | 5.477 | 0.208 | male |
| 0.133 | 0.080 | 0.119 | 0.068 | 0.210 | 0.142 | 1.93 | 8.31 | 0.963 | 0.738 | 0.113 | 0.133 | 0.110 | 0.017 | 0.254 | 0.298 | 0.008 | 2.727 | 2.719 | 0.125 | male |
The data set consists of 3168 rows and 21 columns. Additionally, scanning the first few lines shows a couple of NA values. Let’s continue analysis by removing any column with an NA value.
# keep only the columns that contain no NA values
voice <- voice[ , apply(voice, 2, function(x) !any(is.na(x)))]
This reduces the dataset to 3168 rows and 18 columns. Notice that the number of observations hasn’t changed, but we have fewer columns.
Next, we need to separate the data into a training set and a testing set. The primary purpose of this separation is to allow us to evaluate the model’s accuracy on new data points that weren’t used to build the model. We do this by transforming the label column into a factor and then splitting the data so that the training set has the same class proportions as the original dataset, using the function below.
voice$label <- as.factor(voice$label)
train_test_split <- function(data, groupby_column = ncol(data), samplesize = 0.75, seed = 1) {
  factors <- levels(data[[groupby_column]])
  train <- data.frame()
  test <- data.frame()
  train_test_split <- list()
  for (factor in factors) {
    # separate the original data frame by factor level
    datafactor <- data %>%
      filter(data[, groupby_column] == factor)
    # set the random number generator seed from the input for reproducibility
    set.seed(seed)
    smp_size <- floor(samplesize * nrow(datafactor))
    # build the training and test sets from the randomly generated indices
    train_ind <- sample(seq_len(nrow(datafactor)), size = smp_size)
    trainfactor <- datafactor[train_ind, ]
    testfactor <- datafactor[-train_ind, ]
    # combine the train/test rows for each factor level into the overall sets
    train <- rbind(train, trainfactor)
    test <- rbind(test, testfactor)
  }
  train_test_split[[1]] <- as.data.frame(train)
  train_test_split[[2]] <- as.data.frame(test)
  return(train_test_split)
}
tts <- train_test_split(data = voice, groupby_column = 18, seed = 123, samplesize = 0.75)
train <- tts[[1]]
test <- tts[[2]]
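Because the split is stratified by label, a quick sanity check (assuming the train and test objects built above) is to confirm that the class proportions in both sets match the original data:

```r
# each should show roughly the same female/male proportions as voice$label
prop.table(table(train$label))
prop.table(table(test$label))
```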
The knn function from the class package takes four primary arguments: train, a data frame of training-set predictors; test, a data frame of test-set predictors; cl, the vector of true classifications for the training set; and k, the number of neighbors considered.
To get our two data frames into the appropriate format, we segregate the known classification labels from the data frame.
#separate the two sets to meet KNN function requirements (i.e. remove class labels)
train_class <- train$label
train_nolabel <- train %>%
subset(select = -label) %>%
as.data.frame()
test_class <- test$label
test_nolabel <- test %>%
subset(select = -label) %>%
as.data.frame()
We then build our k-NN model using the knn function:
knn.pred <- knn(train = train_nolabel, test = test_nolabel, cl = train_class, k = 1)
The knn.pred object now contains all of the information about our kNN model that we need. We can provide a confusion matrix and calculate the accuracy rate with the following:
accuracy <- round(mean(knn.pred == test_class) * 100, digits = 3)
table(knn.pred,test_class) %>%
kable() %>%
kable_styling(bootstrap_options = c("bordered", "striped", "hover"), full_width = F, position = 'center') %>%
add_header_above(c("Predicted " = 1, "Actual" = 2))
| | female | male |
|---|---|---|
| female | 273 | 112 |
| male | 123 | 284 |
which has an accuracy of 70.33%.
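That figure is just the diagonal of the confusion matrix divided by its total; given the knn.pred and test_class objects above, it can be recovered directly:

```r
cm <- table(knn.pred, test_class)
round(sum(diag(cm)) / sum(cm) * 100, digits = 2)  # (273 + 284) / 792 = 70.33
```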
However, we’d like to achieve a higher classification accuracy if possible. We are probably feeding too much noise into the model, so let’s see if reducing the variable set produces a better outcome.
We will use the leaps package and its regsubsets function on our original data set to try to reduce the number of variables we keep. If you are interested in learning more about variable subset reduction, be sure to check out the linear model selection tutorial. While we’re at it, let’s limit the maximum number of variables considered to two so that we can plot the results and see what’s happening visually.
First, perform best subset selection with the command best_subset <- regsubsets(label ~ ., voice, nvmax = 2)
best_subset <- regsubsets(label ~ ., voice, nvmax = 2)
summary(best_subset)
## Subset selection object
## Call: regsubsets.formula(label ~ ., voice, nvmax = 2)
## 17 Variables (and intercept)
## Forced in Forced out
## meanfreq FALSE FALSE
## sd FALSE FALSE
## median FALSE FALSE
## Q25 FALSE FALSE
## Q75 FALSE FALSE
## skew FALSE FALSE
## kurt FALSE FALSE
## sp.ent FALSE FALSE
## sfm FALSE FALSE
## meanfun FALSE FALSE
## minfun FALSE FALSE
## maxfun FALSE FALSE
## meandom FALSE FALSE
## mindom FALSE FALSE
## maxdom FALSE FALSE
## IQR FALSE FALSE
## centroid FALSE FALSE
## 1 subsets of each size up to 2
## Selection Algorithm: exhaustive
## meanfreq sd median Q25 Q75 IQR skew kurt sp.ent sfm centroid
## 1 ( 1 ) " " " " " " " " " " " " " " " " " " " " " "
## 2 ( 1 ) " " " " " " " " " " "*" " " " " " " " " " "
## meanfun minfun maxfun meandom mindom maxdom
## 1 ( 1 ) "*" " " " " " " " " " "
## 2 ( 1 ) "*" " " " " " " " " " "
The best subset selection shows that the IQR and meanfun features are the best two-feature combination available.
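Rather than reading the selection table by eye, the chosen variables can be pulled out of the summary object programmatically. A small sketch, assuming the best_subset object fitted above:

```r
bs_summary <- summary(best_subset)
# $which is a logical matrix: one row per model size, one column per term.
# Row 2 is the two-variable model; drop the intercept column.
selected <- names(which(bs_summary$which[2, -1]))
selected  # should contain "IQR" and "meanfun"
```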
We’ll plot these two variables against each other to see if there are any clear clusters in the data set.
ggplot() +
geom_jitter(data = train, aes(x = meanfun, y = IQR, color = label)) +
geom_jitter(data = test, aes(x = meanfun, y = IQR)) +
xlab("Mean Fundamental Frequency") +
ylab("Interquartile Frequency Range") +
ggtitle(label = "Scatterplot of Mean Fundamental Frequency and IQR",
subtitle = "Black Points are test set points") +
theme(plot.title = element_text(hjust = 0.5))
The scatter plot shows two clear clusters for male and female speech samples, and the majority of our test points (black markers) fall within one of these clusters.
Therefore, using these two variables, we will regenerate the kNN model and see if the accuracy improves.
keep <- c("IQR", "meanfun")
test_nolabel <- test_nolabel %>%
  select(all_of(keep))
train_nolabel <- train_nolabel %>%
  select(all_of(keep))
Again, build the k-NN model.
knn.pred <- knn(train = train_nolabel, test = test_nolabel, cl = train_class, k = 1)
accuracy1 <- round(mean(knn.pred == test_class) * 100, digits = 3)
table(knn.pred,test_class) %>%
kable(caption = "Confusion Matrix") %>%
kable_styling(bootstrap_options = c("bordered", "striped", "hover"), full_width = F, position = 'center') %>%
add_header_above(c("Predicted" = 1, "Actual" = 2))
| | female | male |
|---|---|---|
| female | 381 | 17 |
| male | 15 | 379 |
By reducing our feature set, we’ve raised our accuracy to 95.96%, an improvement of 25.63 percentage points.
For each of the models above, k was arbitrarily set to 1, which, as stated in the introduction, means that each unknown test point is classified according to its single nearest neighbor. However, a different k may better balance the model between being too flexible and too inflexible. To find it, we’ll loop over a range of possible k values and see which performs best.
averageaccuracy <- vector()
# pick how many values of k to try
k <- 1:100
for (i in k) {
  # build a model for the i-th value of k
  knn.pred <- knn(train = train_nolabel, test = test_nolabel, cl = train_class, k = i)
  # calculate the accuracy
  averageaccuracy[i] <- mean(knn.pred == test_class)
}
averageaccuracy <- averageaccuracy*100
bestk = which.max(averageaccuracy)
bestaccuracy = max(averageaccuracy)
By finding the best accuracy in the list, we see that the best k for this data set is 9, with an accuracy of 97.727%. We can validate that by plotting the average accuracy vector as k increases.
p <- averageaccuracy %>%
  as_tibble() %>%
ggplot() +
geom_line(aes(x= k, y = averageaccuracy)) +
xlab("Number of k") +
ylab("Accuracy (%)") +
scale_x_continuous(breaks = seq(0,length(k),5) ) +
ggtitle("Test Set Accuracy as k Increases") +
theme(plot.title = element_text(hjust = 0.5))
ggplotly(p = ggplot2::last_plot(), dynamicTicks = F)
The plot shows that as k increases, the test set accuracy peaks at k = 9.
We’ve shown that k-NN can be a fairly effective classification method. However, there are certain things to understand and account for when conducting classification using this method.
k-NN works only with numerical inputs. For non-numerical inputs, dummy variables can be created to transform factors into binary (1 or 0) variables, but this can greatly increase the dimensionality of the problem.
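For instance, base R’s model.matrix() performs this dummy coding; a minimal sketch with a made-up three-level factor:

```r
df <- data.frame(color = factor(c("red", "green", "blue")))
# dropping the intercept yields one 0/1 indicator column per level
model.matrix(~ color - 1, data = df)
```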
Distance calculations are affected by scale. For example, in a dataset with one variable on the order of millions and another on the order of tens, differences in the first variable will dominate at the expense of the second. This can be overcome by standardizing (subtract the mean, divide by the standard deviation) all of the input variables, but analysts should account for the additional computational effort required.
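A sketch of that standardization step, assuming the train_nolabel and test_nolabel data frames built earlier; note that the test set must be scaled with the *training* means and standard deviations to avoid leaking test information into the model:

```r
train_means <- colMeans(train_nolabel)
train_sds <- apply(train_nolabel, 2, sd)
train_scaled <- scale(train_nolabel, center = train_means, scale = train_sds)
test_scaled <- scale(test_nolabel, center = train_means, scale = train_sds)
```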
This classification technique can be computationally expensive on very large datasets. Unlike, say, a regression equation, where a new point is simply plugged into a function and a value is returned, a new data point in a k-NN model is compared to every point in the training set. For extremely large datasets, this may make effective modeling prohibitive.
This tutorial covered the basics of k-Nearest Neighbors. To learn more, visit the Introduction to Statistical Learning in R webpage.
As an exercise, repeat this analysis using the iris data set: what accuracy does a k = 1 model achieve, and which k is best?